package mybatis.generator.support.generator;

import mybatis.generator.FilterType;
import mybatis.generator.MybatisGenerator;
import mybatis.generator.annotation.ExecuteInsert;
import mybatis.generator.annotation.GeneratedSql;
import mybatis.generator.ext.InterceptorContext;
import mybatis.generator.support.Reflection;
import mybatis.generator.support.parse.ColumnValuePair;
import mybatis.generator.support.parse.EntityMeta;
import mybatis.generator.support.parse.TableColumn;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.scripting.LanguageDriver;
import org.apache.ibatis.scripting.xmltags.XMLLanguageDriver;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.defaults.DefaultSqlSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Generates the {@code INSERT} statement for mapper methods that insert entities.
 *
 * <p>Created by huangdachao on 2018/6/14 16:59.
 */
public class InsertGenerator {
    private static final Logger LOG = LoggerFactory.getLogger(InsertGenerator.class);

    /**
     * Creates a {@link SqlSource} that builds an {@code INSERT} statement for {@code method} each time it is executed.
     *
     * <p>The single argument of {@code method} may be:
     * <ul>
     *   <li>a {@link Map}, whose keys are looked up as entity field names;</li>
     *   <li>an instance of the mapper's entity class, in which case every insertable column is used;</li>
     *   <li>another bean, whose field names are matched against the entity's insertable columns;</li>
     *   <li>a {@link Collection} or array of such beans, producing a multi-row insert via MyBatis
     *       {@code <foreach>}.</li>
     * </ul>
     * Extra column/value pairs come from {@link ExecuteInsert#columnValue()} and from registered interceptors.
     *
     * @param conf        the MyBatis configuration used to parse the generated SQL
     * @param mapperClass the mapper interface declaring {@code method}; its entity class is resolved from it
     * @param method      the insert method; it is expected to carry {@link ExecuteInsert}
     * @return a {@link SqlSource} that regenerates the statement for every parameter object
     * @throws IllegalArgumentException if {@code method} does not declare exactly one parameter. The returned
     *                                  source also throws it when the argument is {@code null} or when no
     *                                  insertable column can be derived.
     */
    @SuppressWarnings("unchecked")
    public static SqlSource generate(Configuration conf, Class<?> mapperClass, Method method) {
        if (method.getParameterCount() != 1) {
            throw new IllegalArgumentException("插入方法只允许一个实体类型参数或实体类型的集合类型参数："
                + Reflection.formatMethod(mapperClass, method));
        }

        EntityMeta em = EntityMeta.get(Reflection.getEntityClass(mapperClass));
        // The annotation never changes between calls, so resolve it once up front.
        ExecuteInsert executeInsert = method.getAnnotation(ExecuteInsert.class);
        return parameterObject -> {
            Object arg = Reflection.parseArgs(conf, method, parameterObject)[0];
            if (arg instanceof DefaultSqlSession.StrictMap) {
                // DefaultSqlSession wraps collection/array arguments in a StrictMap; unwrap to the real value.
                arg = ((Map<String, ?>) arg).values().iterator().next();
            }
            if (arg == null) {
                throw new IllegalArgumentException("插入参数不能为null：" + Reflection.formatMethod(mapperClass, method));
            }

            InterceptorContext context = new InterceptorContext(mapperClass, em.getEntityClass(), method,
                new Object[]{arg}, GeneratedSql.Type.INSERT);
            MybatisGenerator.intercept(context);

            Class<?> paramClass = arg.getClass();
            // Name of the MyBatis binding to iterate over, or null for a single-row insert.
            String collectionType = null;
            List<ColumnValuePair> cvp = new ArrayList<>();

            if (arg instanceof Map) {
                cvp.addAll(mergeInsertColumns(em, ((Map<String, ?>) arg).keySet()));
            } else {
                Class<?> elementClass = paramClass;
                if (arg instanceof Collection) {
                    // MyBatis always binds collections as "collection" but only binds "list" for List instances,
                    // so a Set (or any other non-List collection) must iterate over "collection".
                    collectionType = arg instanceof List ? "list" : "collection";
                    ParameterizedType type = (ParameterizedType) Reflection.resolveType(
                        method.getGenericParameterTypes()[0], mapperClass, method.getDeclaringClass());
                    elementClass = (Class<?>) type.getActualTypeArguments()[0];
                } else if (paramClass.isArray()) {
                    collectionType = "array";
                    elementClass = paramClass.getComponentType();
                }

                if (em.getEntityClass() != elementClass) {
                    // A non-entity bean: only its fields that map to insertable columns are used.
                    List<String> fields = Reflection.resolveFields(elementClass, false).stream()
                        .map(Field::getName)
                        .collect(Collectors.toList());
                    cvp.addAll(mergeInsertColumns(em, fields));
                } else {
                    for (TableColumn tc : em.getInsertColumns()) {
                        cvp.add(new ColumnValuePair(tc.column, em.getFieldMeta().get(tc).getField().getName()));
                    }
                }
            }

            // Apply pairs declared on the annotation, then pairs contributed by interceptors. Interceptors
            // may also add extra bound parameters, which are attached to the BoundSql below.
            ColumnValuePair.mergeFromColumnValues(executeInsert.columnValue(), em, cvp);
            Map<String, Object> additionalParams = new HashMap<>();
            ColumnValuePair.mergeFromInterceptorContext(context, additionalParams, cvp);
            if (cvp.isEmpty()) {
                throw new IllegalArgumentException("没有有效的可插入字段" + Reflection.formatMethod(mapperClass, method));
            }

            // Inside <foreach item='i'> each field must be referenced through the loop item.
            String itemPrefix = collectionType != null ? "i." : "";
            List<String> columns = new ArrayList<>(cvp.size());
            List<String> values = new ArrayList<>(cvp.size());
            for (ColumnValuePair p : cvp) {
                columns.add(p.getColumn());
                String placeholder = p.getPlaceholder();
                // Placeholders matching FIELD_PATTERN become #{...} parameters; others are inlined verbatim.
                values.add(FilterType.FIELD_PATTERN.matcher(placeholder).find()
                    ? "#{" + itemPrefix + placeholder + "}"
                    : placeholder);
            }

            String sql = buildInsertSql(em, columns, values, collectionType);
            LOG.info(sql);
            LanguageDriver languageDriver = conf.getLanguageRegistry().getDriver(XMLLanguageDriver.class);
            SqlSource sqlSource = languageDriver.createSqlSource(conf, sql, paramClass);
            BoundSql bsql = sqlSource.getBoundSql(parameterObject);
            additionalParams.forEach(bsql::setAdditionalParameter);
            return bsql;
        };
    }

    /**
     * Renders the INSERT statement text.
     *
     * <p>When {@code collectionType} is non-null, the VALUES tuple is wrapped in a {@code <script>}/{@code <foreach>}
     * block that iterates the binding with that name.
     *
     * @param em             metadata supplying the target table
     * @param columns        column names, in insertion order
     * @param values         value expressions, parallel to {@code columns}
     * @param collectionType MyBatis binding name to iterate over, or {@code null} for a single row
     * @return the SQL (or XML script) text to hand to the language driver
     */
    private static String buildInsertSql(EntityMeta em, List<String> columns, List<String> values,
                                         String collectionType) {
        String columnList = String.join(",", columns);
        String valueList = String.join(",", values);
        if (collectionType == null) {
            return "insert into " + em.getTable().table + "(" + columnList + ") values (" + valueList + ")";
        }
        return "<script>insert into " + em.getTable().table + "(" + columnList
            + ") values <foreach item='i' collection='" + collectionType + "' separator=','>("
            + valueList + ")</foreach></script>";
    }

    /**
     * Maps field names to insertable columns of the entity.
     *
     * @param em     entity metadata
     * @param fields candidate field names
     * @return column/placeholder pairs for every field that maps to an insertable column, in {@code fields} order;
     *         fields without a column, or whose column is not insertable, are skipped
     */
    private static List<ColumnValuePair> mergeInsertColumns(EntityMeta em, Collection<String> fields) {
        Set<TableColumn> insertColumns = em.getInsertColumns();
        List<ColumnValuePair> pairs = new ArrayList<>();
        for (String f : fields) {
            TableColumn tc = em.getTableColumn(f);
            if (tc != null && insertColumns.contains(tc)) {
                pairs.add(new ColumnValuePair(tc.column, f));
            }
        }
        return pairs;
    }
}
